import sys

import torch

def broadcast_mul(a, b):
    """Return the elementwise product ``a * b`` of two tensors.

    PyTorch broadcasting aligns shapes from the trailing dimension: each
    aligned pair of sizes must be equal or one of them must be 1, and a
    missing leading dimension counts as 1. The shapes are checked up front
    so that a failure names both shapes explicitly.

    Raises:
        RuntimeError: if ``a.shape`` and ``b.shape`` cannot be broadcast
            together (the same exception type a bare ``a * b`` raises).
    """
    try:
        torch.broadcast_shapes(a.shape, b.shape)
    except RuntimeError as err:
        raise RuntimeError(
            f"cannot broadcast shapes {tuple(a.shape)} and {tuple(b.shape)}"
        ) from err
    return a * b


def main():
    """Multiply an (8, 8, 8) tensor by a (2, 4) tensor and print the result shape.

    Returns 0 on success and 1 if the shapes are incompatible. With these
    shapes the trailing sizes are 8 and 4 (neither is 1), so they do not
    broadcast; the error is reported on stderr instead of escaping as an
    unhandled traceback.
    """
    t1 = torch.randn(8, 8, 8)
    t2 = torch.randn(2, 4)
    try:
        t3 = broadcast_mul(t1, t2)
    except RuntimeError as err:
        print(err, file=sys.stderr)
        return 1
    print(t3.shape)
    return 0


if __name__ == "__main__":
    sys.exit(main())
#11111
#22222